# 加载数据文件

import pandas as pd
import numpy as np

# Default location of the Excel workbook read by load(). The path is relative,
# so it resolves against the current working directory, not this file's folder.
PATH = './data/data.xlsx'
# Seed for numpy's *global* RNG. This runs once, at import time, and affects
# every np.random call in the process (not only the shuffle inside load()).
SEED = 123
np.random.seed(SEED)

def load(train_scale=0.8, path=None, seed=None):
    """Load the Excel dataset, shuffle its rows and split it into train/test sets.

    Column 0 of the sheet is used as the group label and columns 2 onward as
    the feature data; column 1 is read but not returned.
    # NOTE(review): column 1 is presumably an identifier -- confirm against the file.

    Args:
        train_scale: Fraction of rows, in [0, 1], assigned to the training
            split. The training size is ``int(n_rows * train_scale)``; the
            remaining rows form the test split.
        path: Excel file (path or buffer) to read. ``None`` (default) uses the
            module-level ``PATH``, looked up at call time.
        seed: If given, rows are shuffled with a private
            ``np.random.RandomState(seed)``, so the split is reproducible and
            the global numpy RNG is left untouched. If ``None`` (default), the
            global ``np.random`` state is used and advanced, as before.

    Returns:
        Tuple ``(train_dataset, train_group, test_dataset, test_group)`` of
        NumPy arrays: the two datasets are 2-D (rows x feature columns), the
        two group arrays are 1-D. Their dtype follows the sheet contents.

    Raises:
        ValueError: If ``train_scale`` is not within [0, 1] (including NaN).
    """
    # Validate before any I/O: out-of-range values would otherwise produce
    # silently wrong splits (e.g. a negative slice bound for train_scale < 0).
    if not 0.0 <= train_scale <= 1.0:
        raise ValueError(f"train_scale must be in [0, 1], got {train_scale!r}")

    df = pd.read_excel(PATH if path is None else path)
    data = df.to_numpy()

    dataset = data[:, 2:]
    group = data[:, 0]

    # Same permutation is applied to features and labels so rows stay paired.
    rng = np.random if seed is None else np.random.RandomState(seed)
    permutation = rng.permutation(len(data))
    shuffled_dataset = dataset[permutation, :]
    shuffled_group = group[permutation]

    train_cnt = int(len(shuffled_dataset) * train_scale)
    train_dataset, train_group = shuffled_dataset[:train_cnt, :], shuffled_group[:train_cnt]
    test_dataset, test_group = shuffled_dataset[train_cnt:, :], shuffled_group[train_cnt:]

    return train_dataset, train_group, test_dataset, test_group


if __name__ == '__main__':
    # Smoke check: print the shape of every split in return order
    # (train data, train groups, test data, test groups), then the test groups.
    splits = load()
    for part in splits:
        print(part.shape)
    print(splits[-1])